Skip to content

Conversation

@Rian354
Copy link

@Rian354 Rian354 commented Dec 8, 2025

PR for MedLink bounty

Tests:
To run the MedLink unit tests, from the project root run:

pytest tests/core/test_medlink.py (locally, 3 passed & 1 warning)

Model Implementation:

  • Implemented the MedLink retrieval model on top of the current BaseModel / dataset API.
  • Added unit tests with small synthetic data for MedLink.
  • Added a Jupyter notebook that trains and evaluates MedLink on the MIMIC-III demo dataset.

Additions to "pyhealth/models/medlink/model.py":

  • BaseModel-compatible "MedLink" class that takes a task-generated dataset (e.g., "SampleDataset" from "set_task") and "feature_keys".
  • Vocabulary construction from the underlying task dataset using "dataset.get_all_tokens(...)" for queries and documents.
  • Query and corpus encoders ("encode_queries", "encode_corpus") that produce sparse multi-hot representations.
  • BM25-style scoring in "compute_scores", compatible with the IR-format data produced by the MedLink utilities.
  • Combined retrieval and prediction loss in forward / get_loss, returning a scalar loss for training.

Other changes:

  • Extended SampleDataset w/ get_all_tokens(key: str) to collect unique tokens across samples, used by MedLink for vocabulary building.
  • Implemented BM25 and IR helpers in the pyhealth.models.medlink package:
    • BM25Okapi
    • convert_to_ir_format, tvt_split
    • generate_candidates, filter_by_candidates
    • get_bm25_hard_negatives, get_train_dataloader, get_eval_dataloader
  • Exported MedLink via pyhealth.models.init, so users can do: from pyhealth.models import MedLink

Added examples/medlink_mimic3.ipynb, a runnable notebook that:

Loads the MIMIC-III demo dataset via MIMIC3Dataset.

Defines a patient linkage task to generate query–candidate pairs.

Uses the MedLink helpers to build IR-format data and PyTorch dataloaders.

Trains and evaluates MedLink and reports ranking metrics.

Locally ran:

examples/medlink_mimic3.ipynb runs end-to-end on the MIMIC-III demo dataset.

The notebook includes a note on how to run the MedLink unit tests from project root.

Files to review:

pyhealth/datasets/sample_dataset.py – SampleDataset.get_all_tokens helper for vocabulary construction.

pyhealth/models/medlink/model.py – core MedLink model implementation.

pyhealth/models/medlink/bm25.py – BM25Okapi implementation used in the retrieval pipeline.

pyhealth/models/medlink/utils.py – IR-format conversion, TVT split, candidate generation, dataloaders.

pyhealth/models/init.py – export of MedLink.

tests/core/test_medlink.py – synthetic unit tests for MedLink (forward pass, encoders, score shapes).

examples/medlink_mimic3.ipynb – Jupyter notebook for training and evaluating MedLink on the MIMIC-III demo dataset.

Copy link
Collaborator

@jhnwu3 jhnwu3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll probably add more comments as I have more time to dig deeper into this, but nice first attempt at actually a pretty hard bounty.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some quick thoughts that:

  • Can we move the medlink task into the pyhealth.tasks module too? I actually think it'd be really helpful also to further have detailed documentation surrounding the query/document identifiers. It'd be good to link it up with the original paper's task of mapping records to a master known patient record.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would also be nice to have it in the docs/ as that'll actually be a pretty nice to have for anyone working on record linkage problems.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also try to see if we can't build new processors here to pass to the MedLink model.

Actually, I think the sequence processors should have built-in vocabularies here. But, it would be nice to update the EmbeddingModel to better support things like initialized Glove vectors or just use randomly initailized embeddings for now. This way medlink can be better integrated with the rest of PyHealth, and I think it'd be a nice lesson in replicating the original implementation. (A lot of the techniques are pretty relevant to clinical predictive modeling, so I think it's a good learning exercise).

Example of a PR working with the processors instead of the previous old PyHealth tokenizer approach here: https://github.com/sunlabuiuc/PyHealth/pull/610/files

Glove vectors from the original implementation: https://github.com/zzachw/MedLink here.

@Logiquo Logiquo added component: model Contribute a new model to PyHealth bounty Please see the bounty list in PyHealth Discord Server labels Dec 18, 2025
@Rian354
Copy link
Author

Rian354 commented Dec 22, 2025

Summary

Updated with content that completes the Medlink bounty via adding the processor native implementation as requested (integrated w/ Pyhealth 2.x processors), unit tests with synthetic data, and an end to end MIMIC-III demo. It also has optional pretrained embedding initialization (glove style vectors) to better match the original Medlink implementation.

Changes

1) MedLink model implementation (processor native)

  • Refactored pyhealth/models/medlink/model.py to use SequenceProcessor vocab over the legacy tokenizer approach.
  • Builds a shared vocabulary across query and document processors (conditions and d_conditions) and remaps indices so embeddings are consistent between both.
  • Contains randomly initialized embeddings by default, and optional pretrained embedding initialization (glove style) when given.

2) Pretrained embeddings support

  • Updated EmbeddingModel to support loading pretrained embedding vectors from a text file and optionally freezing them.
  • This allows for reproducing the original Medlink workflow (optional) while keeping default behavior as just random initialization.

3) Unit tests

  • Added/updated tests/core/test_medlink.py with synthetic data to validate:
    • forward pass returns a scalar loss
    • query/corpus encoding shapes are consistent w/ vocab
    • score matrix shape matches (num_queries, num_docs)
  • Tests are lightweight and do not require MIMIC.

4) End-to-end example notebook

  • Added examples/medlink_mimic3.ipynb showing:
    • loading MIMIC-III demo data
    • building patient linkage samples
    • training Medlink for a few epochs
    • evaluating with ranking metrics

Bounty requirement coverage from feedback

  • Implement MedLink model (using PyHealth 2.x patterns; processor-driven)
  • Add a unit test with pseudo data that runs a forward pass (tests/core/test_medlink.py)
  • Add an example notebook demonstrating training on MIMIC-III (examples/medlink_mimic3.ipynb)

How to run tests

From repo root:

pip install -e .
pytest -q tests/core/test_medlink.py

Copy link
Collaborator

@jhnwu3 jhnwu3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very close! Thanks for the hardwork!

Copy link
Collaborator

@jhnwu3 jhnwu3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make sure our documentation (docstrings) clearly documents what the exact inputs and outputs are for the model?

@Rian354 Rian354 requested a review from jhnwu3 December 24, 2025 22:32
@Logiquo Logiquo added the status: need review Pending maintainer's review label Dec 27, 2025
@Rian354
Copy link
Author

Rian354 commented Dec 27, 2025

The tests seem to initialize sampledataset w/ samples parameter, but its expecting a schema file. So I added a helper to build the dataset w/ in memory samples. Should pass tests now

Copy link
Collaborator

@jhnwu3 jhnwu3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's cool that it runs, but I think there are couple things that i'm concerned about here, so please do comment on it if I'm misunderstanding some things here:

  • The EmbeddingModel changes don't seem to appear in the MedLink model itself. It would be good if we could unify them or at least otherwise, we can separate them from the EmbeddingModel, and simply assume that Glove embeddings are required for an accurate MedLink model.
  • The docstrings need to be updated with the new API using a create_sample_dataset since SampleDataset no longer accepts a samples list anymore due to our updated streaming backend
  • There seems to be a confusion around tokenizers and where to place them. For reference, you can use a Processor class if you want a specific List[str] format or if you need it to be in embedding format, you can also define a GloveProcessor. There's a couple ways of going about this. Happy to discuss more, but I think it'd be good to separate the model from the tokenizer/what we call a processor in PyHealth
  • The TaskClass upon further inspection seems to have schemas that don't follow what the docstrings say.

@@ -1,5 +1,5 @@
from typing import Set
from base import BaseTestCase
from tests.base import BaseTestCase
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really understand this change? Is it just because of the Python versioning? Doesn't seem to lead to issues though.

>>> samples = [{"patient_id": "1", "admissions": ["ICD9_430", "ICD9_401"]}, ...]
>>> input_schema = {"admissions": "code"}
>>> output_schema = {"label": "binary"}
>>> dataset = SampleDataset(path="/some/path", samples=samples, input_schema=input_schema, output_schema=output_schema)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SampleDataset example here needs to use the create_sample_dataset() function

# Set feature_keys manually since BaseModel extracts it from dataset.input_schema
# but MedLink needs to use the provided feature_keys
self.feature_keys = feature_keys
self.feature_key = feature_keys[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the first one specifically? If need be, we can always define a medlink processor to make certain assumptions on the inputs.

For reference on processors, see here: https://pyhealth.readthedocs.io/en/latest/api/processors.html

Essentially, a tokenizer is a type of processor in our framework.

self.criterion = nn.BCEWithLogitsLoss()

def _encode_tokens(self, seqs: List[List[str]], device: torch.device):
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For instance, this here could be a processor call()

Edit: I just realized the forward() call here doesn't use _encode_tokens?

"""

task_name = "patient_linkage_mimic3"
input_schema = {"conditions": "sequence", "d_conditions": "sequence"}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you construct a MedLink processor, we can also change our schemas to be something like "q_conditions" : "medlink" (you can also use the class name itself like "q_conditions" : MedLinkProcessor() if that's easier to map in code) and "d_conditions" : "medlink" if that helps since it seems like the model does expect a sequence, but it encodes it in a different way with different Glove embeddings rather than the current EmbeddingModel approach?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also make a List[str] processor if you want to assume strings or something of that sort. Right now, the schemas don't seem to mean much at all, and thus it's just code that adds to the confusion of what is being inputted and outputted here.

For instance, the doc strings don't match up with the code here.

self.tokenizer = tokenizer
self.vocab_size = tokenizer.get_vocabulary_size()

self.embedding = nn.Embedding(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how this links up with the EmbeddingModel changes to use Glove vectors?

>>> from pyhealth.datasets import SampleDataset
>>> from pyhealth.models import MedLink
>>> samples = [{"patient_id": "1", "admissions": ["ICD9_430", "ICD9_401"]}, ...]
>>> input_schema = {"admissions": "code"}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can simply assume that only certain tasks are compatible with the model, so we don't have to explicitly specify a feature_keys argument here that makes things a little more confusing and more opaque to the user.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bounty Please see the bounty list in PyHealth Discord Server component: model Contribute a new model to PyHealth status: need review Pending maintainer's review

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants